#%%
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf

import gym

# output tensor
from network.Network_Definition import probability, newGrads, opt_updateGrads
# input tensor, placeholder
from network.Network_Definition import observations, input_y, advantages, batchGrad
from network.Network_Definition import D

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Number of episodes whose gradients are accumulated before each parameter
# update of the policy network (one training batch == this many episodes).
batch_size = 25
# Discount factor applied to future rewards.
gamma = 0.99

env = gym.make('CartPole-v0')

# ---------------------------------------------------------------------------
# Training state
# ---------------------------------------------------------------------------
# Per-episode buffers: observations, chosen actions, and received rewards.
xs = []
ys = []
drs = []
# Sum of rewards collected over all episodes of the current batch.
reward_sum = 0
# Episode counter: starts at 1 and is advanced each time an episode ends.
episode_number = 1
# The training loop runs while episode_number <= total_episodes.
total_episodes = 10000


def discount_rewards(r, discount=None):
    """Compute the discounted cumulative return at every step of an episode.

    Rewards are walked backwards in time so that
    ``out[t] = r[t] + discount * out[t + 1]`` (with nothing beyond the last
    step). Each entry is therefore the discounted sum of that step's reward
    and every later reward in the episode.

    Args:
        r: Per-step rewards ordered in time along the first axis. A 1-D
            array, a ``(T, 1)`` column (as built with ``np.vstack`` in the
            training loop), or a plain list are all accepted.
        discount: Discount factor. When omitted, the module-level ``gamma``
            is used, looked up at call time so existing callers are unchanged.

    Returns:
        An array with the same shape as ``r``. Non-floating inputs (e.g.
        integer rewards) are promoted to float64 so the discounted values
        are not truncated.
    """
    if discount is None:
        discount = gamma
    r = np.asarray(r)
    # zeros_like(r) alone would keep an integer dtype and truncate results.
    out_dtype = r.dtype if np.issubdtype(r.dtype, np.floating) else np.float64
    discounted_r = np.zeros_like(r, dtype=out_dtype)
    running_add = 0
    # Iterate over the time axis (len, not size, so 2-D inputs stay correct).
    for t in reversed(range(len(r))):
        running_add = running_add * discount + r[t]
        discounted_r[t] = running_add
    return discounted_r
        

# Launch the graph and train the policy network with policy gradients.
# Gradients from each episode are accumulated and applied to the network
# once every `batch_size` episodes.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Gradient accumulator: one zero-filled numpy array per trainable
    # variable. Per-episode gradients are summed here until a full batch of
    # episodes has been played, then applied to the network in one update.
    gradBuffer = [np.zeros_like(var) for var in sess.run(tf.trainable_variables())]

    # Number of finished episodes, counted from 0. Batching is driven by this
    # counter rather than episode_number (which starts at 1), so that every
    # batch, including the first, contains exactly `batch_size` episodes.
    episodes_completed = 0

    # CartPole-v0 caps episodes at 200 steps with reward 1 per step, so the
    # batch-average reward can never exceed 200 (a `> 200` test is
    # unreachable). Gym registers CartPole-v0 with reward_threshold 195.0.
    solved_reward = 195.0

    # Initialise the environment and obtain the first observation.
    observation = env.reset()

    while episode_number <= total_episodes:
        # Rendering (env.render()) is intentionally disabled here because it
        # slows training down considerably.

        # Make sure the observation is in a shape the network can handle.
        x = np.reshape(observation, [1, D])

        # `probability` is the network's probability of choosing action 0;
        # action 1 is therefore chosen with probability 1 - probability.
        tfprob = sess.run(probability, feed_dict={observations: x})
        y = 0 if np.random.uniform() < tfprob else 1

        # Record this step's observation and chosen action for the episode.
        xs.append(x)
        ys.append(y)

        # Step the environment and get new measurements.
        observation, reward, done, info = env.step(y)
        reward_sum += reward

        # Record the reward (only available after step() for the action taken).
        drs.append(reward)

        if done:
            episode_number += 1
            episodes_completed += 1

            # Stack the episode's inputs, actions and rewards, then reset the
            # per-episode buffers.
            epx = np.vstack(xs)
            epy = np.vstack(ys)
            epr = np.vstack(drs)
            xs, ys, drs = [], [], []

            # Compute the discounted return backwards through time and
            # normalise it to zero mean / unit variance to reduce the variance
            # of the gradient estimator. The epsilon prevents a division by
            # zero (and NaN gradients) when all returns are identical.
            discounted_epr = discount_rewards(epr)
            discounted_epr -= np.mean(discounted_epr)
            discounted_epr /= np.std(discounted_epr) + 1e-8

            # Gradients of the loss over this episode's states, actions and
            # advantages. Per the original notes, the loss is defined
            # elsewhere roughly as:
            #   loglik = tf.log(input_y * (input_y - probability) + (1 - input_y) * (input_y + probability))
            #   loss = -tf.reduce_mean(loglik * advantages)
            tGrad = sess.run(newGrads, feed_dict={observations: epx, input_y: epy, advantages: discounted_epr})

            # Add this episode's gradients into the batch accumulator.
            for ix, grad in enumerate(tGrad):
                gradBuffer[ix] += grad

            # After a full batch of episodes, apply the accumulated gradients
            # to the network parameters and clear the accumulator.
            if episodes_completed % batch_size == 0:
                sess.run(opt_updateGrads, feed_dict=dict(zip(batchGrad, gradBuffer)))
                gradBuffer = [np.zeros_like(grad) for grad in gradBuffer]

                # Report the average per-episode reward over this batch.
                average_reward = reward_sum / batch_size
                print('Average reward for episode %d : %f.' % (episodes_completed, average_reward))

                if average_reward >= solved_reward:
                    print("Task solved in", episodes_completed, 'episodes!')
                    break

                # Start accumulating reward for the next batch.
                reward_sum = 0

            # Reset the environment to begin the next episode.
            observation = env.reset()
        
